{
  "cells": [
    {
      "attachments": {},
      "cell_type": "markdown",
      "metadata": {
        "id": "pjFPwBDf3UGP"
      },
      "source": [
        "# Learning from Heterogeneous Graphs"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "fTWmy3gCJ_8a"
      },
      "outputs": [],
      "source": [
        "import torch\n",
        "\n",
        "# Install the PyG extension wheels from the index matching the torch version imported above.\n",
        "# `%pip` (rather than `!pip`) installs into the environment of the running kernel.\n",
        "%pip install -q torch-scatter~=2.1.0 torch-sparse~=0.6.16 torch-cluster~=1.6.0 torch-spline-conv~=1.2.1 torch-geometric==2.2.0 -f https://data.pyg.org/whl/torch-{torch.__version__}.html\n",
        "\n",
        "# Seed the CPU and CUDA generators, request deterministic cuDNN kernels\n",
        "# and disable cuDNN autotuning so that runs are repeatable.\n",
        "torch.manual_seed(0)\n",
        "torch.cuda.manual_seed(0)\n",
        "torch.cuda.manual_seed_all(0)\n",
        "torch.backends.cudnn.deterministic = True\n",
        "torch.backends.cudnn.benchmark = False"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 2,
      "metadata": {
        "id": "Z0FxDj17J1ir"
      },
      "outputs": [],
      "source": [
        "import numpy as np\n",
        "np.random.seed(0)\n",
        "\n",
        "import torch\n",
        "torch.manual_seed(0)\n",
        "from torch.nn import Linear\n",
        "from torch_geometric.nn import MessagePassing\n",
        "from torch_geometric.utils import add_self_loops, degree\n",
        "\n",
        "class GCNConv(MessagePassing):\n",
        "    \"\"\"Graph convolution with symmetric degree normalisation.\n",
        "\n",
        "    Node features are linearly projected, then every edge message is scaled by\n",
        "    ``deg(source)^-0.5 * deg(target)^-0.5`` and summed at the target node.\n",
        "\n",
        "    Args:\n",
        "        dim_in: Size of the input node features.\n",
        "        dim_h: Size of the output node features.\n",
        "    \"\"\"\n",
        "\n",
        "    def __init__(self, dim_in, dim_h):\n",
        "        super().__init__(aggr='add')\n",
        "        self.linear = Linear(dim_in, dim_h, bias=False)\n",
        "\n",
        "    def forward(self, x, edge_index):\n",
        "        \"\"\"Return the convolved node features.\n",
        "\n",
        "        Args:\n",
        "            x: Node features; ``x.size(0)`` is taken as the number of nodes.\n",
        "            edge_index: Two rows of node indices (unpacked below as ``row, col``),\n",
        "                one column per edge.\n",
        "        \"\"\"\n",
        "        # Self-loops let each node keep its own features in the aggregation.\n",
        "        edge_index, _ = add_self_loops(edge_index, num_nodes=x.size(0))\n",
        "\n",
        "        x = self.linear(x)\n",
        "\n",
        "        row, col = edge_index\n",
        "        deg = degree(col, x.size(0), dtype=x.dtype)\n",
        "        deg_inv_sqrt = deg.pow(-0.5)\n",
        "        # Guard against division by zero for nodes with degree 0.\n",
        "        deg_inv_sqrt[deg_inv_sqrt == float('inf')] = 0\n",
        "        norm = deg_inv_sqrt[row] * deg_inv_sqrt[col]\n",
        "\n",
        "        return self.propagate(edge_index, x=x, norm=norm)\n",
        "\n",
        "    def message(self, x_j, norm):\n",
        "        \"\"\"Scale each per-edge source feature row by its normalisation weight.\"\"\"\n",
        "        # The `_j` suffix makes `propagate` gather `x` per edge ([num_edges, dim_h]).\n",
        "        # A bare `x` would arrive un-gathered ([num_nodes, dim_h]) and would not\n",
        "        # line up with the per-edge `norm` vector.\n",
        "        return norm.view(-1, 1) * x_j"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 3,
      "metadata": {
        "id": "VOxgX0sQKCTx"
      },
      "outputs": [],
      "source": [
        "# Instantiate the custom layer: 16 input features -> 32 output features.\n",
        "conv = GCNConv(dim_in=16, dim_h=32)"
      ]
    },
    {
      "attachments": {},
      "cell_type": "markdown",
      "metadata": {
        "id": "KqjK7PVZKtLg"
      },
      "source": [
        "## Heterogeneous graphs"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 4,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "EWnTph5t3SW5",
        "outputId": "01cd6b23-3850-4b0a-c79d-72cbab87bf45"
      },
      "outputs": [
        {
          "data": {
            "text/plain": [
              "HeteroData(\n",
              "  \u001b[1muser\u001b[0m={ x=[3, 4] },\n",
              "  \u001b[1mgame\u001b[0m={ x=[2, 2] },\n",
              "  \u001b[1mdev\u001b[0m={ x=[2, 1] },\n",
              "  \u001b[1m(user, follows, user)\u001b[0m={ edge_index=[2, 2] },\n",
              "  \u001b[1m(user, plays, game)\u001b[0m={\n",
              "    edge_index=[2, 4],\n",
              "    edge_attr=[4, 1]\n",
              "  },\n",
              "  \u001b[1m(dev, develops, game)\u001b[0m={ edge_index=[2, 2] }\n",
              ")"
            ]
          },
          "execution_count": 4,
          "metadata": {},
          "output_type": "execute_result"
        }
      ],
      "source": [
        "from torch_geometric.data import HeteroData\n",
        "\n",
        "data = HeteroData()\n",
        "\n",
        "# Node features, one row per node of the given type.\n",
        "data['user'].x = torch.Tensor([[1, 1, 1, 1], [2, 2, 2, 2], [3, 3, 3, 3]]) # [num_users, num_features_users]\n",
        "data['game'].x = torch.Tensor([[1, 1], [2, 2]])\n",
        "data['dev'].x = torch.Tensor([[1], [2]])\n",
        "\n",
        "# Connectivity stores node indices, so it must be an integer (long) tensor.\n",
        "# `torch.Tensor(...)` would create float32 values, which cannot be used as indices.\n",
        "data['user', 'follows', 'user'].edge_index = torch.tensor([[0, 1], [1, 2]], dtype=torch.long) # [2, num_edges_follows]\n",
        "data['user', 'plays', 'game'].edge_index = torch.tensor([[0, 1, 1, 2], [0, 0, 1, 1]], dtype=torch.long)\n",
        "data['dev', 'develops', 'game'].edge_index = torch.tensor([[0, 1], [0, 1]], dtype=torch.long)\n",
        "\n",
        "# One attribute row per `plays` edge.\n",
        "data['user', 'plays', 'game'].edge_attr = torch.Tensor([[2], [0.5], [10], [12]])\n",
        "\n",
        "data"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 5,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "6LM57bFo_HTl",
        "outputId": "48465ce5-8521-483e-85f1-bade4f9f7ed6"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "HeteroData(\n",
            "  metapath_dict={ (author, metapath_0, author)=[2] },\n",
            "  \u001b[1mauthor\u001b[0m={\n",
            "    x=[4057, 334],\n",
            "    y=[4057],\n",
            "    train_mask=[4057],\n",
            "    val_mask=[4057],\n",
            "    test_mask=[4057]\n",
            "  },\n",
            "  \u001b[1mpaper\u001b[0m={ x=[14328, 4231] },\n",
            "  \u001b[1mterm\u001b[0m={ x=[7723, 50] },\n",
            "  \u001b[1mconference\u001b[0m={ num_nodes=20 },\n",
            "  \u001b[1m(author, metapath_0, author)\u001b[0m={ edge_index=[2, 11113] }\n",
            ")\n",
            "Epoch:   0 | Train Loss: 1.3834 | Train Acc: 27.00% | Val Acc: 29.25%\n",
            "Epoch:  20 | Train Loss: 1.2458 | Train Acc: 49.50% | Val Acc: 45.50%\n",
            "Epoch:  40 | Train Loss: 1.1287 | Train Acc: 65.50% | Val Acc: 58.00%\n",
            "Epoch:  60 | Train Loss: 1.0278 | Train Acc: 75.50% | Val Acc: 65.25%\n",
            "Epoch:  80 | Train Loss: 0.9415 | Train Acc: 79.75% | Val Acc: 68.00%\n",
            "Epoch: 100 | Train Loss: 0.8675 | Train Acc: 83.50% | Val Acc: 71.00%\n",
            "Test accuracy: 73.29%\n"
          ]
        }
      ],
      "source": [
        "from torch import nn\n",
        "import torch.nn.functional as F\n",
        "\n",
        "import torch_geometric.transforms as T\n",
        "from torch_geometric.datasets import DBLP\n",
        "from torch_geometric.nn import GAT\n",
        "\n",
        "# Add an author -> paper -> author meta-path relation and drop the original edge types.\n",
        "metapaths = [[('author', 'paper'), ('paper', 'author')]]\n",
        "transform = T.AddMetaPaths(metapaths=metapaths, drop_orig_edge_types=True)\n",
        "dataset = DBLP('.', transform=transform)\n",
        "data = dataset[0]\n",
        "print(data)\n",
        "\n",
        "model = GAT(in_channels=-1, hidden_channels=64, out_channels=4, num_layers=1)\n",
        "\n",
        "optimizer = torch.optim.Adam(model.parameters(), lr=0.001, weight_decay=0.001)\n",
        "device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
        "data, model = data.to(device), model.to(device)\n",
        "\n",
        "@torch.no_grad()\n",
        "def test(mask):\n",
        "    \"\"\"Return the accuracy (as a float) of the author predictions selected by `mask`.\"\"\"\n",
        "    model.eval()\n",
        "    author_x = data.x_dict['author']\n",
        "    author_edges = data.edge_index_dict[('author', 'metapath_0', 'author')]\n",
        "    pred = model(author_x, author_edges).argmax(dim=-1)\n",
        "    n_correct = (pred[mask] == data['author'].y[mask]).sum()\n",
        "    return float(n_correct / mask.sum())\n",
        "\n",
        "for epoch in range(101):\n",
        "    model.train()\n",
        "    optimizer.zero_grad()\n",
        "    mask = data['author'].train_mask\n",
        "    out = model(data.x_dict['author'], data.edge_index_dict[('author', 'metapath_0', 'author')])\n",
        "    loss = F.cross_entropy(out[mask], data['author'].y[mask])\n",
        "    loss.backward()\n",
        "    optimizer.step()\n",
        "\n",
        "    # Only report every 20th epoch.\n",
        "    if epoch % 20 != 0:\n",
        "        continue\n",
        "    train_acc = test(data['author'].train_mask)\n",
        "    val_acc = test(data['author'].val_mask)\n",
        "    print(f'Epoch: {epoch:>3} | Train Loss: {loss:.4f} | Train Acc: {train_acc*100:.2f}% | Val Acc: {val_acc*100:.2f}%')\n",
        "\n",
        "test_acc = test(data['author'].test_mask)\n",
        "print(f'Test accuracy: {test_acc*100:.2f}%')"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 6,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "nFROVbIHhDoj",
        "outputId": "48e0bd72-d4b6-41f6-9d72-b0f3e99adbc9"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "GraphModule(\n",
            "  (conv): ModuleDict(\n",
            "    (author__to__paper): GATConv((-1, -1), 64, heads=1)\n",
            "    (paper__to__author): GATConv((-1, -1), 64, heads=1)\n",
            "    (paper__to__term): GATConv((-1, -1), 64, heads=1)\n",
            "    (paper__to__conference): GATConv((-1, -1), 64, heads=1)\n",
            "    (term__to__paper): GATConv((-1, -1), 64, heads=1)\n",
            "    (conference__to__paper): GATConv((-1, -1), 64, heads=1)\n",
            "  )\n",
            "  (linear): ModuleDict(\n",
            "    (author): Linear(in_features=64, out_features=4, bias=True)\n",
            "    (paper): Linear(in_features=64, out_features=4, bias=True)\n",
            "    (term): Linear(in_features=64, out_features=4, bias=True)\n",
            "    (conference): Linear(in_features=64, out_features=4, bias=True)\n",
            "  )\n",
            ")\n",
            "\n",
            "\n",
            "\n",
            "def forward(self, x, edge_index):\n",
            "    x_dict = torch_geometric_nn_to_hetero_transformer_get_dict(x);  x = None\n",
            "    x__author = x_dict.get('author', None)\n",
            "    x__paper = x_dict.get('paper', None)\n",
            "    x__term = x_dict.get('term', None)\n",
            "    x__conference = x_dict.get('conference', None);  x_dict = None\n",
            "    edge_index_dict = torch_geometric_nn_to_hetero_transformer_get_dict(edge_index);  edge_index = None\n",
            "    edge_index__author__to__paper = edge_index_dict.get(('author', 'to', 'paper'), None)\n",
            "    edge_index__paper__to__author = edge_index_dict.get(('paper', 'to', 'author'), None)\n",
            "    edge_index__paper__to__term = edge_index_dict.get(('paper', 'to', 'term'), None)\n",
            "    edge_index__paper__to__conference = edge_index_dict.get(('paper', 'to', 'conference'), None)\n",
            "    edge_index__term__to__paper = edge_index_dict.get(('term', 'to', 'paper'), None)\n",
            "    edge_index__conference__to__paper = edge_index_dict.get(('conference', 'to', 'paper'), None);  edge_index_dict = None\n",
            "    conv__paper1 = self.conv.author__to__paper((x__author, x__paper), edge_index__author__to__paper);  edge_index__author__to__paper = None\n",
            "    conv__author = self.conv.paper__to__author((x__paper, x__author), edge_index__paper__to__author);  x__author = edge_index__paper__to__author = None\n",
            "    conv__term = self.conv.paper__to__term((x__paper, x__term), edge_index__paper__to__term);  edge_index__paper__to__term = None\n",
            "    conv__conference = self.conv.paper__to__conference((x__paper, x__conference), edge_index__paper__to__conference);  edge_index__paper__to__conference = None\n",
            "    conv__paper2 = self.conv.term__to__paper((x__term, x__paper), edge_index__term__to__paper);  x__term = edge_index__term__to__paper = None\n",
            "    conv__paper3 = self.conv.conference__to__paper((x__conference, x__paper), edge_index__conference__to__paper);  x__conference = x__paper = edge_index__conference__to__paper = None\n",
            "    conv__paper_1 = torch.add(conv__paper1, conv__paper2);  conv__paper1 = conv__paper2 = None\n",
            "    conv__paper = torch.add(conv__paper3, conv__paper_1);  conv__paper3 = conv__paper_1 = None\n",
            "    relu__author = conv__author.relu();  conv__author = None\n",
            "    relu__paper = conv__paper.relu();  conv__paper = None\n",
            "    relu__term = conv__term.relu();  conv__term = None\n",
            "    relu__conference = conv__conference.relu();  conv__conference = None\n",
            "    linear__author = self.linear.author(relu__author);  relu__author = None\n",
            "    linear__paper = self.linear.paper(relu__paper);  relu__paper = None\n",
            "    linear__term = self.linear.term(relu__term);  relu__term = None\n",
            "    linear__conference = self.linear.conference(relu__conference);  relu__conference = None\n",
            "    return {'author': linear__author, 'paper': linear__paper, 'term': linear__term, 'conference': linear__conference}\n",
            "    \n",
            "# To see more debug info, please use `graph_module.print_readable()`\n",
            "Epoch:   0 | Train Loss: 1.3864 | Train Acc: 31.25% | Val Acc: 25.75%\n",
            "Epoch:  20 | Train Loss: 1.1945 | Train Acc: 89.00% | Val Acc: 56.75%\n",
            "Epoch:  40 | Train Loss: 0.8576 | Train Acc: 94.50% | Val Acc: 65.75%\n",
            "Epoch:  60 | Train Loss: 0.5049 | Train Acc: 98.00% | Val Acc: 73.25%\n",
            "Epoch:  80 | Train Loss: 0.2687 | Train Acc: 99.25% | Val Acc: 76.75%\n",
            "Epoch: 100 | Train Loss: 0.1574 | Train Acc: 100.00% | Val Acc: 76.50%\n",
            "Test accuracy: 78.39%\n"
          ]
        }
      ],
      "source": [
        "from torch_geometric.nn import GATConv, Linear, to_hetero\n",
        "\n",
        "dataset = DBLP(root='.')\n",
        "data = dataset[0]\n",
        "\n",
        "# Conference nodes come without features: give each one a constant placeholder so that\n",
        "# every node type has an `x`. The size is read from the graph rather than hard-coded.\n",
        "data['conference'].x = torch.zeros(data['conference'].num_nodes, 1)\n",
        "\n",
        "class GAT(torch.nn.Module):\n",
        "    \"\"\"A single GATConv layer followed by a linear classification head.\n",
        "\n",
        "    Args:\n",
        "        dim_h: Hidden size produced by the attention layer.\n",
        "        dim_out: Number of output classes.\n",
        "    \"\"\"\n",
        "\n",
        "    def __init__(self, dim_h, dim_out):\n",
        "        super().__init__()\n",
        "        # (-1, -1): source/target input sizes are inferred lazily on the first forward pass.\n",
        "        self.conv = GATConv((-1, -1), dim_h, add_self_loops=False)\n",
        "        self.linear = nn.Linear(dim_h, dim_out)\n",
        "\n",
        "    def forward(self, x, edge_index):\n",
        "        h = self.conv(x, edge_index).relu()\n",
        "        h = self.linear(h)\n",
        "        return h\n",
        "\n",
        "# Build the homogeneous model, then convert it to a heterogeneous one\n",
        "# (one conv per edge type and one linear head per node type; per-type results are summed).\n",
        "model = GAT(dim_h=64, dim_out=4)\n",
        "model = to_hetero(model, data.metadata(), aggr='sum')\n",
        "print(model)\n",
        "\n",
        "optimizer = torch.optim.Adam(model.parameters(), lr=0.001, weight_decay=0.001)\n",
        "device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
        "data, model = data.to(device), model.to(device)\n",
        "\n",
        "@torch.no_grad()\n",
        "def test(mask):\n",
        "    \"\"\"Return the accuracy (as a float) of the author predictions selected by `mask`.\"\"\"\n",
        "    model.eval()\n",
        "    pred = model(data.x_dict, data.edge_index_dict)['author'].argmax(dim=-1)\n",
        "    acc = (pred[mask] == data['author'].y[mask]).sum() / mask.sum()\n",
        "    return float(acc)\n",
        "\n",
        "for epoch in range(101):\n",
        "    model.train()\n",
        "    optimizer.zero_grad()\n",
        "    # The converted model returns one output per node type; only authors are labelled.\n",
        "    out = model(data.x_dict, data.edge_index_dict)['author']\n",
        "    mask = data['author'].train_mask\n",
        "    loss = F.cross_entropy(out[mask], data['author'].y[mask])\n",
        "    loss.backward()\n",
        "    optimizer.step()\n",
        "\n",
        "    if epoch % 20 == 0:\n",
        "        train_acc = test(data['author'].train_mask)\n",
        "        val_acc = test(data['author'].val_mask)\n",
        "        print(f'Epoch: {epoch:>3} | Train Loss: {loss:.4f} | Train Acc: {train_acc*100:.2f}% | Val Acc: {val_acc*100:.2f}%')\n",
        "\n",
        "test_acc = test(data['author'].test_mask)\n",
        "print(f'Test accuracy: {test_acc*100:.2f}%')"
      ]
    },
    {
      "attachments": {},
      "cell_type": "markdown",
      "metadata": {
        "id": "xADuuEQseUrj"
      },
      "source": [
        "## Heterogeneous Graph Attention Network (HAN)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 7,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "0AcupaiIddjz",
        "outputId": "63bece01-a990-440e-f5d5-002dced07c83"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "HeteroData(\n",
            "  \u001b[1mauthor\u001b[0m={\n",
            "    x=[4057, 334],\n",
            "    y=[4057],\n",
            "    train_mask=[4057],\n",
            "    val_mask=[4057],\n",
            "    test_mask=[4057]\n",
            "  },\n",
            "  \u001b[1mpaper\u001b[0m={ x=[14328, 4231] },\n",
            "  \u001b[1mterm\u001b[0m={ x=[7723, 50] },\n",
            "  \u001b[1mconference\u001b[0m={ num_nodes=20 },\n",
            "  \u001b[1m(author, to, paper)\u001b[0m={ edge_index=[2, 19645] },\n",
            "  \u001b[1m(paper, to, author)\u001b[0m={ edge_index=[2, 19645] },\n",
            "  \u001b[1m(paper, to, term)\u001b[0m={ edge_index=[2, 85810] },\n",
            "  \u001b[1m(paper, to, conference)\u001b[0m={ edge_index=[2, 14328] },\n",
            "  \u001b[1m(term, to, paper)\u001b[0m={ edge_index=[2, 85810] },\n",
            "  \u001b[1m(conference, to, paper)\u001b[0m={ edge_index=[2, 14328] }\n",
            ")\n",
            "Epoch:   0 | Train Loss: 1.3829 | Train Acc: 49.75% | Val Acc: 37.75%\n",
            "Epoch:  20 | Train Loss: 1.1551 | Train Acc: 86.50% | Val Acc: 60.75%\n",
            "Epoch:  40 | Train Loss: 0.7695 | Train Acc: 94.00% | Val Acc: 67.50%\n",
            "Epoch:  60 | Train Loss: 0.4750 | Train Acc: 97.75% | Val Acc: 73.75%\n",
            "Epoch:  80 | Train Loss: 0.3008 | Train Acc: 99.25% | Val Acc: 78.25%\n",
            "Epoch: 100 | Train Loss: 0.2247 | Train Acc: 99.50% | Val Acc: 78.75%\n",
            "Test accuracy: 81.58%\n"
          ]
        }
      ],
      "source": [
        "import torch\n",
        "import torch.nn.functional as F\n",
        "from torch import nn\n",
        "\n",
        "import torch_geometric.transforms as T\n",
        "from torch_geometric.datasets import DBLP\n",
        "from torch_geometric.nn import HANConv, Linear\n",
        "\n",
        "\n",
        "dataset = DBLP('.')\n",
        "data = dataset[0]\n",
        "print(data)\n",
        "\n",
        "# Conference nodes come without features: give each one a constant placeholder so that\n",
        "# every node type has an `x`. The size is read from the graph rather than hard-coded.\n",
        "data['conference'].x = torch.zeros(data['conference'].num_nodes, 1)\n",
        "\n",
        "class HAN(nn.Module):\n",
        "    \"\"\"HANConv layer followed by a linear head that classifies author nodes.\n",
        "\n",
        "    Args:\n",
        "        dim_in: Input feature size passed to HANConv (-1 is used below for lazy inference).\n",
        "        dim_out: Number of output classes.\n",
        "        dim_h: Hidden size produced by HANConv.\n",
        "        heads: Number of attention heads.\n",
        "        dropout: Dropout value passed to HANConv.\n",
        "        metadata: Graph metadata (node types, edge types). When omitted, it is\n",
        "            read from the notebook-level `data` object, as before.\n",
        "    \"\"\"\n",
        "\n",
        "    def __init__(self, dim_in, dim_out, dim_h=128, heads=8, dropout=0.6, metadata=None):\n",
        "        super().__init__()\n",
        "        if metadata is None:\n",
        "            # Backward-compatible fallback to the global graph.\n",
        "            metadata = data.metadata()\n",
        "        self.han = HANConv(dim_in, dim_h, heads=heads, dropout=dropout, metadata=metadata)\n",
        "        self.linear = nn.Linear(dim_h, dim_out)\n",
        "\n",
        "    def forward(self, x_dict, edge_index_dict):\n",
        "        out = self.han(x_dict, edge_index_dict)\n",
        "        # HANConv returns one embedding per node type; only authors are classified.\n",
        "        out = self.linear(out['author'])\n",
        "        return out\n",
        "\n",
        "model = HAN(dim_in=-1, dim_out=4, metadata=data.metadata())\n",
        "\n",
        "optimizer = torch.optim.Adam(model.parameters(), lr=0.001, weight_decay=0.001)\n",
        "device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
        "data, model = data.to(device), model.to(device)\n",
        "\n",
        "@torch.no_grad()\n",
        "def test(mask):\n",
        "    \"\"\"Return the accuracy (as a float) of the author predictions selected by `mask`.\"\"\"\n",
        "    model.eval()\n",
        "    pred = model(data.x_dict, data.edge_index_dict).argmax(dim=-1)\n",
        "    acc = (pred[mask] == data['author'].y[mask]).sum() / mask.sum()\n",
        "    return float(acc)\n",
        "\n",
        "for epoch in range(101):\n",
        "    model.train()\n",
        "    optimizer.zero_grad()\n",
        "    out = model(data.x_dict, data.edge_index_dict)\n",
        "    mask = data['author'].train_mask\n",
        "    loss = F.cross_entropy(out[mask], data['author'].y[mask])\n",
        "    loss.backward()\n",
        "    optimizer.step()\n",
        "\n",
        "    if epoch % 20 == 0:\n",
        "        train_acc = test(data['author'].train_mask)\n",
        "        val_acc = test(data['author'].val_mask)\n",
        "        print(f'Epoch: {epoch:>3} | Train Loss: {loss:.4f} | Train Acc: {train_acc*100:.2f}% | Val Acc: {val_acc*100:.2f}%')\n",
        "\n",
        "test_acc = test(data['author'].test_mask)\n",
        "print(f'Test accuracy: {test_acc*100:.2f}%')"
      ]
    }
  ],
  "metadata": {
    "colab": {
      "provenance": []
    },
    "kernelspec": {
      "display_name": "book",
      "language": "python",
      "name": "python3"
    },
    "language_info": {
      "codemirror_mode": {
        "name": "ipython",
        "version": 3
      },
      "file_extension": ".py",
      "mimetype": "text/x-python",
      "name": "python",
      "nbconvert_exporter": "python",
      "pygments_lexer": "ipython3",
      "version": "3.8.15"
    },
    "vscode": {
      "interpreter": {
        "hash": "3556630122da5213751af4465d61fcf5a52cd22515d400aee51118aaa1721248"
      }
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}
